{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Copyright (c) Microsoft Corporation. All rights reserved.\n",
    "\n",
    "Licensed under the MIT License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GenSen with PyTorch\n",
    "In this tutorial, you will train a GenSen model for the sentence similarity task. We use the [SNLI](https://nlp.stanford.edu/projects/snli/) dataset in this example. For a more detailed walkthrough of data processing, see [SNLI Data Prep](../01-prep-data/snli.ipynb). A quickstart version of this notebook can be found [here](../00-quick-start/).\n",
    "\n",
    "## Notes:\n",
    "The model training part of this notebook can only run on a GPU machine. The running time shown in the notebook is on a Standard_NC6 Azure VM with 1 NVIDIA Tesla K80 GPU and 12 GB GPU memory. See the [README](README.md) for more details of the running time."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "### What is GenSen?\n",
    "\n",
    "GenSen is a technique to learn general purpose, fixed-length representations of sentences via multi-task training. The GenSen model combines the benefits of diverse sentence-representation learning objectives into a single multi-task framework. \"This is the first large-scale reusable sentence representation model obtained by combining a set of training objectives with the level of diversity explored here, i.e. multi-lingual NMT, natural language inference, constituency parsing and skip-thought vectors.\" [\\[1\\]](#References) These representations are useful for transfer and low-resource learning. GenSen is trained on several data sources with multiple training objectives on over 100 million sentences.\n",
    "\n",
    "The GenSen model is most similar to that of Luong et al. (2015) [\\[4\\]](#References), who train a many-to-many **sequence-to-sequence** model on a diverse set of weakly related tasks that includes machine translation, constituency parsing, image captioning, sequence autoencoding, and intra-sentence skip-thoughts. However, there are two key differences. \"First, like McCann et al. (2017) [\\[5\\]](#References), their use of an attention mechanism prevents learning a fixed-length vector representation for a sentence. Second, their work aims for improvements on the same tasks on which the model is trained, as opposed to learning re-usable sentence representations that transfer elsewhere.\" [\\[1\\]](#References)\n",
    "\n",
    "### Why GenSen?\n",
    "\n",
    "The GenSen model achieves state-of-the-art results on multiple sentence similarity datasets, such as MRPC, SICK-R, SICK-E and STS. The reported results, compared with other models, are as follows [\\[3\\]](#References):\n",
    "\n",
    "| Model | MRPC (accuracy/F1) | SICK-R (Pearson's r) | SICK-E (accuracy) | STS (Pearson/Spearman) |\n",
    "| --- | --- | --- | --- | --- |\n",
    "| GenSen (Subramanian et al., 2018) | 78.6/84.4 | 0.888 | 87.8 | 78.9/78.6 |\n",
    "| [InferSent](https://arxiv.org/abs/1705.02364) (Conneau et al., 2017) | 76.2/83.1 | 0.884 | 86.3 | 75.8/75.5 |\n",
    "| [TF-KLD](https://www.aclweb.org/anthology/D13-1090) (Ji and Eisenstein, 2013) | 80.4/85.9 | - | - | - |"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Outline\n",
    "This notebook is organized as follows:\n",
    "\n",
    "1. Data preparation and inspection.\n",
    "2. Model training and prediction.\n",
    "\n",
    "For a deeper dive into the GenSen model, check out the [GenSen Deep Dive Notebook](gensen_aml_deep_dive.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Global Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "System version: 3.6.8 |Anaconda, Inc.| (default, Dec 30 2018, 01:22:34) \n",
      "[GCC 7.3.0]\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append(\"../..\")\n",
    "\n",
    "import os\n",
    "import papermill as pm\n",
    "import scrapbook as sb\n",
    "\n",
    "from utils_nlp.dataset.preprocess import to_lowercase, to_nltk_tokens\n",
    "from utils_nlp.dataset import snli, preprocess\n",
    "from utils_nlp.models.pretrained_embeddings.glove import download_and_extract\n",
    "from utils_nlp.dataset import Split\n",
    "from examples.sentence_similarity.gensen_wrapper import GenSenClassifier\n",
    "\n",
    "print(\"System version: {}\".format(sys.version))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "max_epoch = None\n",
    "config_filepath = 'gensen_config.json'\n",
    "base_data_path = '../../data'\n",
    "nrows = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Data Preparation and inspection"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The [SNLI](https://nlp.stanford.edu/projects/snli/) corpus (version 1.0) is a collection of 570k human-written English sentence pairs manually labeled for balanced classification with the labels entailment, contradiction, and neutral, supporting the task of natural language inference (NLI), also known as recognizing textual entailment (RTE). "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.1 Load the dataset\n",
    "\n",
    "We provide a function `load_pandas_df` which does the following:\n",
    "\n",
    "* Downloads the SNLI zipfile to the specified directory location\n",
    "* Extracts the file based on the specified split\n",
    "* Loads the split as a pandas dataframe\n",
    "\n",
    "The zipfile contains the following files:\n",
    "\n",
    "* snli_1.0_dev.txt\n",
    "* snli_1.0_train.txt\n",
    "* snli_1.0_test.txt\n",
    "* snli_1.0_dev.jsonl\n",
    "* snli_1.0_train.jsonl\n",
    "* snli_1.0_test.jsonl\n",
    "\n",
    "The loader defaults to reading from the .txt file; however, you can change this to .jsonl by setting the optional `file_type` parameter when calling the function."
   ]
  },
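  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the loading step concrete, the sketch below reads a tab-separated split into a pandas dataframe and shows how a row limit such as `nrows` truncates the read. It is only an illustration under stated assumptions: the helper `load_split` and the three sample rows are hypothetical, the tab-separated layout of the .txt files is assumed, and the real `load_pandas_df` also handles downloading and extraction.\n",
    "\n",
    "```python\n",
    "import io\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def load_split(buffer, nrows=None):\n",
    "    # Hypothetical helper: read a tab-separated split; nrows=None loads every row.\n",
    "    return pd.read_csv(buffer, sep='\\t', nrows=nrows)\n",
    "\n",
    "\n",
    "# Tiny in-memory stand-in for one split file (illustrative rows only).\n",
    "sample = pd.DataFrame(\n",
    "    {\n",
    "        'gold_label': ['neutral', 'contradiction', 'entailment'],\n",
    "        'sentence1': ['A person on a horse.'] * 3,\n",
    "        'sentence2': ['A person trains.', 'A person is at a diner.', 'A person is outdoors.'],\n",
    "    }\n",
    ")\n",
    "text = sample.to_csv(sep='\\t', index=False)\n",
    "\n",
    "full = load_split(io.StringIO(text))\n",
    "head = load_split(io.StringIO(text), nrows=2)\n",
    "print(full.shape, head.shape)\n",
    "```\n",
    "\n",
    "Passing a small `nrows` is a convenient way to smoke-test the rest of the notebook before loading the full 570k pairs."
   ]
  },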
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>gold_label</th>\n",
       "      <th>sentence1_binary_parse</th>\n",
       "      <th>sentence2_binary_parse</th>\n",
       "      <th>sentence1_parse</th>\n",
       "      <th>sentence2_parse</th>\n",
       "      <th>sentence1</th>\n",
       "      <th>sentence2</th>\n",
       "      <th>captionID</th>\n",
       "      <th>pairID</th>\n",
       "      <th>label1</th>\n",
       "      <th>label2</th>\n",
       "      <th>label3</th>\n",
       "      <th>label4</th>\n",
       "      <th>label5</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>neutral</td>\n",
       "      <td>( ( ( A person ) ( on ( a horse ) ) ) ( ( jump...</td>\n",
       "      <td>( ( A person ) ( ( is ( ( training ( his horse...</td>\n",
       "      <td>(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN o...</td>\n",
       "      <td>(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) ...</td>\n",
       "      <td>A person on a horse jumps over a broken down a...</td>\n",
       "      <td>A person is training his horse for a competition.</td>\n",
       "      <td>3416050480.jpg#4</td>\n",
       "      <td>3416050480.jpg#4r1n</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>contradiction</td>\n",
       "      <td>( ( ( A person ) ( on ( a horse ) ) ) ( ( jump...</td>\n",
       "      <td>( ( A person ) ( ( ( ( is ( at ( a diner ) ) )...</td>\n",
       "      <td>(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN o...</td>\n",
       "      <td>(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) ...</td>\n",
       "      <td>A person on a horse jumps over a broken down a...</td>\n",
       "      <td>A person is at a diner, ordering an omelette.</td>\n",
       "      <td>3416050480.jpg#4</td>\n",
       "      <td>3416050480.jpg#4r1c</td>\n",
       "      <td>contradiction</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>entailment</td>\n",
       "      <td>( ( ( A person ) ( on ( a horse ) ) ) ( ( jump...</td>\n",
       "      <td>( ( A person ) ( ( ( ( is outdoors ) , ) ( on ...</td>\n",
       "      <td>(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN o...</td>\n",
       "      <td>(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) ...</td>\n",
       "      <td>A person on a horse jumps over a broken down a...</td>\n",
       "      <td>A person is outdoors, on a horse.</td>\n",
       "      <td>3416050480.jpg#4</td>\n",
       "      <td>3416050480.jpg#4r1e</td>\n",
       "      <td>entailment</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>neutral</td>\n",
       "      <td>( Children ( ( ( smiling and ) waving ) ( at c...</td>\n",
       "      <td>( They ( are ( smiling ( at ( their parents ) ...</td>\n",
       "      <td>(ROOT (NP (S (NP (NNP Children)) (VP (VBG smil...</td>\n",
       "      <td>(ROOT (S (NP (PRP They)) (VP (VBP are) (VP (VB...</td>\n",
       "      <td>Children smiling and waving at camera</td>\n",
       "      <td>They are smiling at their parents</td>\n",
       "      <td>2267923837.jpg#2</td>\n",
       "      <td>2267923837.jpg#2r1n</td>\n",
       "      <td>neutral</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>entailment</td>\n",
       "      <td>( Children ( ( ( smiling and ) waving ) ( at c...</td>\n",
       "      <td>( There ( ( are children ) present ) )</td>\n",
       "      <td>(ROOT (NP (S (NP (NNP Children)) (VP (VBG smil...</td>\n",
       "      <td>(ROOT (S (NP (EX There)) (VP (VBP are) (NP (NN...</td>\n",
       "      <td>Children smiling and waving at camera</td>\n",
       "      <td>There are children present</td>\n",
       "      <td>2267923837.jpg#2</td>\n",
       "      <td>2267923837.jpg#2r1e</td>\n",
       "      <td>entailment</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      gold_label                             sentence1_binary_parse  \\\n",
       "0        neutral  ( ( ( A person ) ( on ( a horse ) ) ) ( ( jump...   \n",
       "1  contradiction  ( ( ( A person ) ( on ( a horse ) ) ) ( ( jump...   \n",
       "2     entailment  ( ( ( A person ) ( on ( a horse ) ) ) ( ( jump...   \n",
       "3        neutral  ( Children ( ( ( smiling and ) waving ) ( at c...   \n",
       "4     entailment  ( Children ( ( ( smiling and ) waving ) ( at c...   \n",
       "\n",
       "                              sentence2_binary_parse  \\\n",
       "0  ( ( A person ) ( ( is ( ( training ( his horse...   \n",
       "1  ( ( A person ) ( ( ( ( is ( at ( a diner ) ) )...   \n",
       "2  ( ( A person ) ( ( ( ( is outdoors ) , ) ( on ...   \n",
       "3  ( They ( are ( smiling ( at ( their parents ) ...   \n",
       "4             ( There ( ( are children ) present ) )   \n",
       "\n",
       "                                     sentence1_parse  \\\n",
       "0  (ROOT (S (NP (NP (DT A) (NN person)) (PP (IN o...   \n",
       "1  (ROOT (S (NP (NP (DT A) (NN person)) (PP (IN o...   \n",
       "2  (ROOT (S (NP (NP (DT A) (NN person)) (PP (IN o...   \n",
       "3  (ROOT (NP (S (NP (NNP Children)) (VP (VBG smil...   \n",
       "4  (ROOT (NP (S (NP (NNP Children)) (VP (VBG smil...   \n",
       "\n",
       "                                     sentence2_parse  \\\n",
       "0  (ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) ...   \n",
       "1  (ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) ...   \n",
       "2  (ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) ...   \n",
       "3  (ROOT (S (NP (PRP They)) (VP (VBP are) (VP (VB...   \n",
       "4  (ROOT (S (NP (EX There)) (VP (VBP are) (NP (NN...   \n",
       "\n",
       "                                           sentence1  \\\n",
       "0  A person on a horse jumps over a broken down a...   \n",
       "1  A person on a horse jumps over a broken down a...   \n",
       "2  A person on a horse jumps over a broken down a...   \n",
       "3              Children smiling and waving at camera   \n",
       "4              Children smiling and waving at camera   \n",
       "\n",
       "                                           sentence2         captionID  \\\n",
       "0  A person is training his horse for a competition.  3416050480.jpg#4   \n",
       "1      A person is at a diner, ordering an omelette.  3416050480.jpg#4   \n",
       "2                  A person is outdoors, on a horse.  3416050480.jpg#4   \n",
       "3                  They are smiling at their parents  2267923837.jpg#2   \n",
       "4                         There are children present  2267923837.jpg#2   \n",
       "\n",
       "                pairID         label1 label2 label3 label4 label5  \n",
       "0  3416050480.jpg#4r1n        neutral    NaN    NaN    NaN    NaN  \n",
       "1  3416050480.jpg#4r1c  contradiction    NaN    NaN    NaN    NaN  \n",
       "2  3416050480.jpg#4r1e     entailment    NaN    NaN    NaN    NaN  \n",
       "3  2267923837.jpg#2r1n        neutral    NaN    NaN    NaN    NaN  \n",
       "4  2267923837.jpg#2r1e     entailment    NaN    NaN    NaN    NaN  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train = snli.load_pandas_df(base_data_path, file_split=Split.TRAIN, nrows=nrows)\n",
    "dev = snli.load_pandas_df(base_data_path, file_split=Split.DEV, nrows=nrows)\n",
    "test = snli.load_pandas_df(base_data_path, file_split=Split.TEST, nrows=nrows)\n",
    "\n",
    "train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 Tokenize\n",
    "\n",
    "Now that we have loaded the dataset into pandas dataframes, we convert the sentences to tokens. We also clean the data before tokenizing. This includes dropping unnecessary columns and renaming the relevant columns as score, sentence1, and sentence2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def clean_and_tokenize(df):\n",
    "    df = snli.clean_cols(df)\n",
    "    df = snli.clean_rows(df)\n",
    "    df = preprocess.to_lowercase(df)\n",
    "    df = preprocess.to_nltk_tokens(df)\n",
    "    return df\n",
    "\n",
    "train = clean_and_tokenize(train)\n",
    "dev = clean_and_tokenize(dev)\n",
    "test = clean_and_tokenize(test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once we have the clean pandas dataframes, we do lowercase standardization and tokenization. We use the [NLTK](https://www.nltk.org/) library for tokenization."
   ]
  },
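  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cleaning, lowercasing and tokenization steps can be sketched in plain pandas as follows. This is not the `utils_nlp` implementation: `simple_tokenize` is a regex stand-in for the NLTK tokenizer, `clean_and_tokenize_sketch` is a hypothetical name, and dropping pairs whose label is `-` (no annotator consensus) is an assumption about what row cleaning involves.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def simple_tokenize(text):\n",
    "    # Regex stand-in for NLTK: runs of word characters, or single punctuation marks.\n",
    "    return re.findall(r'\\w+|[^\\w\\s]', text)\n",
    "\n",
    "\n",
    "def clean_and_tokenize_sketch(df):\n",
    "    # Keep only the label and the two sentences, renaming the label to score.\n",
    "    df = df[['gold_label', 'sentence1', 'sentence2']].rename(columns={'gold_label': 'score'})\n",
    "    # Assumption: pairs labeled '-' have no consensus label and are dropped.\n",
    "    df = df[df['score'] != '-'].reset_index(drop=True)\n",
    "    for col in ['sentence1', 'sentence2']:\n",
    "        df[col] = df[col].str.lower()\n",
    "        df[col + '_tokens'] = df[col].apply(simple_tokenize)\n",
    "    return df\n",
    "\n",
    "\n",
    "# Illustrative rows only, not taken from the dataset.\n",
    "raw = pd.DataFrame(\n",
    "    {\n",
    "        'gold_label': ['entailment', '-'],\n",
    "        'sentence1': ['Two women are embracing.', 'A dog runs.'],\n",
    "        'sentence2': ['Two Women hug!', 'A cat sleeps.'],\n",
    "        'pairID': ['a', 'b'],\n",
    "    }\n",
    ")\n",
    "out = clean_and_tokenize_sketch(raw)\n",
    "print(out['sentence2_tokens'][0])\n",
    "```\n",
    "\n",
    "The resulting columns mirror the cleaned dataframe in this notebook: score, sentence1, sentence2, sentence1_tokens and sentence2_tokens."
   ]
  },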
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>score</th>\n",
       "      <th>sentence1</th>\n",
       "      <th>sentence2</th>\n",
       "      <th>sentence1_tokens</th>\n",
       "      <th>sentence2_tokens</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>neutral</td>\n",
       "      <td>two women are embracing while holding to go pa...</td>\n",
       "      <td>the sisters are hugging goodbye while holding ...</td>\n",
       "      <td>[two, women, are, embracing, while, holding, t...</td>\n",
       "      <td>[the, sisters, are, hugging, goodbye, while, h...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>entailment</td>\n",
       "      <td>two women are embracing while holding to go pa...</td>\n",
       "      <td>two woman are holding packages.</td>\n",
       "      <td>[two, women, are, embracing, while, holding, t...</td>\n",
       "      <td>[two, woman, are, holding, packages, .]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>contradiction</td>\n",
       "      <td>two women are embracing while holding to go pa...</td>\n",
       "      <td>the men are fighting outside a deli.</td>\n",
       "      <td>[two, women, are, embracing, while, holding, t...</td>\n",
       "      <td>[the, men, are, fighting, outside, a, deli, .]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>entailment</td>\n",
       "      <td>two young children in blue jerseys, one with t...</td>\n",
       "      <td>two kids in numbered jerseys wash their hands.</td>\n",
       "      <td>[two, young, children, in, blue, jerseys, ,, o...</td>\n",
       "      <td>[two, kids, in, numbered, jerseys, wash, their...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>neutral</td>\n",
       "      <td>two young children in blue jerseys, one with t...</td>\n",
       "      <td>two kids at a ballgame wash their hands.</td>\n",
       "      <td>[two, young, children, in, blue, jerseys, ,, o...</td>\n",
       "      <td>[two, kids, at, a, ballgame, wash, their, hand...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           score                                          sentence1  \\\n",
       "0        neutral  two women are embracing while holding to go pa...   \n",
       "1     entailment  two women are embracing while holding to go pa...   \n",
       "2  contradiction  two women are embracing while holding to go pa...   \n",
       "3     entailment  two young children in blue jerseys, one with t...   \n",
       "4        neutral  two young children in blue jerseys, one with t...   \n",
       "\n",
       "                                           sentence2  \\\n",
       "0  the sisters are hugging goodbye while holding ...   \n",
       "1                    two woman are holding packages.   \n",
       "2               the men are fighting outside a deli.   \n",
       "3     two kids in numbered jerseys wash their hands.   \n",
       "4           two kids at a ballgame wash their hands.   \n",
       "\n",
       "                                    sentence1_tokens  \\\n",
       "0  [two, women, are, embracing, while, holding, t...   \n",
       "1  [two, women, are, embracing, while, holding, t...   \n",
       "2  [two, women, are, embracing, while, holding, t...   \n",
       "3  [two, young, children, in, blue, jerseys, ,, o...   \n",
       "4  [two, young, children, in, blue, jerseys, ,, o...   \n",
       "\n",
       "                                    sentence2_tokens  \n",
       "0  [the, sisters, are, hugging, goodbye, while, h...  \n",
       "1            [two, woman, are, holding, packages, .]  \n",
       "2     [the, men, are, fighting, outside, a, deli, .]  \n",
       "3  [two, kids, in, numbered, jerseys, wash, their...  \n",
       "4  [two, kids, at, a, ballgame, wash, their, hand...  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dev.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##  2. Model application, performance and analysis of the results\n",
    "The model is implemented as the `GenSenClassifier` class, with the training specifics hidden inside the `fit()` method. The workflow has three steps:\n",
    "\n",
    "**Model initialization**: This is where we tell our class how to train the model. The main parameters to specify are:\n",
    "1. the config file, which contains information about the number of training epochs, the minibatch size, etc.\n",
    "2. `cache_dir`, the folder where all the data will be saved.\n",
    "3. the learning rate for the model.\n",
    "4. the path to the pretrained embedding vectors.\n",
    "\n",
    "**Model fit**: This is where we train the model on the data. The method takes three arguments: the training, dev and test set pandas dataframes. Note that the model is trained only on the training set; the test set is used to report the accuracy of the trained model, which in turn is an estimate of the generalization capabilities of the algorithm. It is generally useful to look at these quantities to get a first idea of the optimization behaviour.\n",
    "\n",
    "**Model prediction**: This is where we generate the similarity for a pair of sentences. Once the model has been trained and we are satisfied with its overall accuracy, we use the saved model to show the similarity between two provided sentences."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.0 Download pretrained vectors\n",
    "In this example we use GloVe for pretrained embedding vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vector file already exists. No changes made.\n"
     ]
    }
   ],
   "source": [
    "pretrained_embedding_path = download_and_extract(base_data_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 Initialize Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "clf = GenSenClassifier(\n",
    "    config_file=config_filepath,\n",
    "    pretrained_embedding_path=pretrained_embedding_path,\n",
    "    learning_rate=0.0001,\n",
    "    cache_dir=base_data_path,\n",
    "    max_epoch=max_epoch,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 Train Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/anaconda/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/modules/rnn.py:46: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.8 and num_layers=1\n",
      "  \"num_layers={}\".format(dropout, num_layers))\n",
      "../../examples/sentence_similarity/gensen_train.py:431: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_.\n",
      "  torch.nn.utils.clip_grad_norm(model.parameters(), 1.0)\n",
      "../../utils_nlp/models/gensen/utils.py:364: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.\n",
      "  Variable(torch.LongTensor(sorted_src_lens), volatile=True)\n",
      "/data/anaconda/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/functional.py:1332: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n",
      "  warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n",
      "/data/anaconda/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/functional.py:1320: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.\n",
      "  warnings.warn(\"nn.functional.tanh is deprecated. Use torch.tanh instead.\")\n",
      "../../examples/sentence_similarity/gensen_train.py:523: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_.\n",
      "  torch.nn.utils.clip_grad_norm(model.parameters(), 1.0)\n",
      "/data/anaconda/envs/nlp_gpu/lib/python3.6/site-packages/horovod/torch/__init__.py:163: UserWarning: optimizer.step(synchronize=True) called after optimizer.synchronize(). This can cause training slowdown. You may want to consider using optimizer.step(synchronize=False) if you use optimizer.synchronize() in your code.\n",
      "  warnings.warn(\"optimizer.step(synchronize=True) called after \"\n",
      "../../examples/sentence_similarity/gensen_train.py:243: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  f.softmax(class_logits).data.cpu().numpy().argmax(axis=-1)\n",
      "../../examples/sentence_similarity/gensen_train.py:262: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  f.softmax(class_logits).data.cpu().numpy().argmax(axis=-1)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1h 19min 28s, sys: 22min 1s, total: 1h 41min 30s\n",
      "Wall time: 1h 41min 22s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "clf.fit(train, dev, test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 Predict\n",
    "\n",
    "The `predict` method computes Pearson's correlation [\\[2\\]](#References) on the outputs of the model. The predictions can be further improved by hyperparameter tuning, which we walk through in the [GenSen deep dive notebook](gensen_aml_deep_dive.ipynb)."
   ]
  },
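  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the Pearson step, the sketch below computes the pairwise correlation matrix for a small set of vectors with NumPy. It assumes the correlated model outputs are one fixed-length vector per sentence; the function `pearson_similarity` and the 4-dimensional vectors are made up for illustration and are not part of `GenSenClassifier`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def pearson_similarity(vectors):\n",
    "    # np.corrcoef treats each row as one variable and returns the matrix of\n",
    "    # pairwise Pearson correlation coefficients between rows.\n",
    "    return pd.DataFrame(np.corrcoef(np.asarray(vectors, dtype=float)))\n",
    "\n",
    "\n",
    "# Hypothetical sentence vectors, one row per sentence.\n",
    "vectors = [\n",
    "    [0.1, 0.4, 0.3, 0.9],\n",
    "    [0.2, 0.5, 0.2, 0.8],\n",
    "]\n",
    "scores = pearson_similarity(vectors)\n",
    "print(scores)\n",
    "```\n",
    "\n",
    "The result has the same layout as the similarity table printed by the notebook: a symmetric matrix with ones on the diagonal and one row and column per input sentence."
   ]
  },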
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "******** Similarity Score for sentences **************\n",
      "          0         1\n",
      "0  1.000000  0.966793\n",
      "1  0.966793  1.000000\n"
     ]
    }
   ],
   "source": [
    "sentences = [\n",
    "        'The sky is blue and beautiful',\n",
    "        'Love this blue and beautiful sky!'\n",
    "    ]\n",
    "\n",
    "results = clf.predict(sentences)\n",
    "print(\"******** Similarity Score for sentences **************\")\n",
    "print(results)\n",
    "\n",
    "# Record results with scrapbook for tests\n",
    "sb.glue(\"results\", results.to_dict())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "1. Sandeep Subramanian, Adam Trischler, Yoshua Bengio, and Christopher J. Pal. [*Learning general purpose distributed sentence representations via large scale multi-task learning*](https://arxiv.org/abs/1804.00079), ICLR, 2018.\n",
    "2. Pearson's Correlation Coefficient. url: https://en.wikipedia.org/wiki/Pearson_correlation_coefficient\n",
    "3. Semantic textual similarity. url: http://nlpprogress.com/english/semantic_textual_similarity.html\n",
    "4. Minh-Thang Luong, Quoc V Le, Ilya Sutskever, Oriol Vinyals, and Lukasz Kaiser. [*Multi-task sequence to sequence learning*](https://arxiv.org/abs/1511.06114), 2015.\n",
    "5. Bryan McCann, James Bradbury, Caiming Xiong, and Richard Socher. [*Learned in translation: Contextualized word vectors*](https://arxiv.org/abs/1708.00107), 2017."
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "nlp_gpu",
   "language": "python",
   "name": "nlp_gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
